Skip to content

[Cuda] enable turboquant on gemma4#19891

Open
metascroy wants to merge 9 commits into
mainfrom
turbo-cuda
Open

[Cuda] enable turboquant on gemma4#19891
metascroy wants to merge 9 commits into
mainfrom
turbo-cuda

Conversation

@metascroy
Copy link
Copy Markdown
Contributor

No description provided.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19891

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending

As of commit 0d0c1ba with merge base 5395f20 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 30, 2026
@metascroy
Copy link
Copy Markdown
Contributor Author

@claude review this code

@github-actions
Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@claude
Copy link
Copy Markdown

claude Bot commented May 30, 2026

Claude finished @metascroy's task in 1m 39s —— View job


Code Review: [Cuda] enable turboquant on gemma4

  • Read changed files and gather context
  • Review backends/cuda/triton/kernels/tq4_sdpa.py changes
  • Review examples/models/gemma4_31b/cuda_source_transformations.py (new file)
  • Review examples/models/gemma4_31b/export.py changes
  • Post review feedback

Summary

This PR adds TurboQuant TQ4 KV cache compression support for the CUDA backend on Gemma 4 31B, extending the existing MLX-only support. The approach monkey-patches the attention forward on full-attention layers to call the fused tq4_sdpa kernel with compressed KV cache. It also adds a scale parameter to tq4_sdpa to support Gemma 4's QK-norm (which uses scale=1.0 instead of the default 1/sqrt(D)).


Bug: CLI validation blocks CUDA + turboquant

examples/models/gemma4_31b/export.py:456

if args.turboquant and args.backend != "mlx":
    parser.error("--turboquant requires --backend mlx.")

This guard was not updated — it still rejects --turboquant unless the backend is mlx. The _export_cuda function now supports turboquant, but users can never reach that code path from CLI. This should be:

if args.turboquant and args.backend not in ("mlx", "cuda"):
    parser.error("--turboquant requires --backend mlx or --backend cuda.")

Or simply removed since both backends now support it.

Fix this →


Other observations

  1. tq4_sdpa.py — scale parameter addition (lines 643, 664-667, 679, 757-760): Clean addition. The default behavior is preserved (None1/sqrt(D)) and the explicit path is gated properly. The fake kernel signature also matches. LGTM.

  2. cuda_source_transformations.py — monkey-patching approach: The Qwen3 integration uses a flag (attn.turboquant = True) with an in-model branch, while this uses full forward replacement via types.MethodType. Both are reasonable — the monkey-patch is more self-contained since the Gemma 4 31B model doesn't have a turboquant codepath built in. One concern: if the model's forward signature or internal logic changes, the monkey-patch will silently diverge. That's a standard tradeoff with source transformations and acceptable here.

  3. cuda_source_transformations.py:118 — unused config variable:

    config = model.config

    This is assigned but never used (the loop accesses attributes directly from attn). It's harmless but unnecessary.

  4. cuda_source_transformations.py:69input_pos shape assumption: The code does torch.outer(input_pos.float(), self.inv_freq) which requires input_pos to be 1-D. The model's actual forward uses the same pattern, so this is fine — just noting the contract.

  5. No test coverage: There's no test for the new CUDA turboquant path. The Qwen3 integration also appears to lack unit tests for this path, so this may be intentional (tested via e2e export). Worth noting.


Verdict

The scale parameter addition to tq4_sdpa and the source transformation module look correct. The CLI validation bug at line 456 is a real blocker — users cannot actually use --turboquant --backend cuda from the command line despite the PR enabling the underlying functionality.


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant